/*
 * Copyright (c) 2022. China Mobile (SuZhou) Software Technology Co.,Ltd. All rights reserved.
 * Lakehouse is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package com.chinamobile.cmss.lakehouse.api.service.impl;

import com.chinamobile.cmss.lakehouse.api.dto.SqlAbortResultBean;
import com.chinamobile.cmss.lakehouse.api.dto.SqlQueryParamBean;
import com.chinamobile.cmss.lakehouse.api.service.HiveService;
import com.chinamobile.cmss.lakehouse.api.service.LakehouseInstanceService;
import com.chinamobile.cmss.lakehouse.api.service.SqlConsoleService;
import com.chinamobile.cmss.lakehouse.common.Constants;
import com.chinamobile.cmss.lakehouse.common.dto.ExecuteSQLBean;
import com.chinamobile.cmss.lakehouse.common.dto.PageBean;
import com.chinamobile.cmss.lakehouse.common.dto.SqlConsoleSQLResult;
import com.chinamobile.cmss.lakehouse.common.dto.SqlJobQueryParamDto;
import com.chinamobile.cmss.lakehouse.common.dto.SqlJobScheduleResultDto;
import com.chinamobile.cmss.lakehouse.common.dto.SqlQueryJobItemDto;
import com.chinamobile.cmss.lakehouse.common.dto.SqlQueryResultDto;
import com.chinamobile.cmss.lakehouse.common.dto.engine.LakehouseResponse;
import com.chinamobile.cmss.lakehouse.common.enums.ClusterStatusTypeEnum;
import com.chinamobile.cmss.lakehouse.common.enums.EngineType;
import com.chinamobile.cmss.lakehouse.common.enums.HttpStatus;
import com.chinamobile.cmss.lakehouse.common.enums.Status;
import com.chinamobile.cmss.lakehouse.common.enums.TaskStatusTypeEnum;
import com.chinamobile.cmss.lakehouse.common.utils.PageInfo;
import com.chinamobile.cmss.lakehouse.common.utils.ServiceException;
import com.chinamobile.cmss.lakehouse.common.utils.SqlDruidUtil;
import com.chinamobile.cmss.lakehouse.dao.SparkFilenameTaskDao;
import com.chinamobile.cmss.lakehouse.dao.SparkSQLTaskInfoDao;
import com.chinamobile.cmss.lakehouse.dao.entity.EngineTaskInfoEntity;
import com.chinamobile.cmss.lakehouse.dao.entity.LakehouseClusterInfoEntity;
import com.chinamobile.cmss.lakehouse.dao.entity.SparkFilenameTaskEntity;
import com.chinamobile.cmss.lakehouse.dao.entity.UserEntity;
import com.chinamobile.cmss.lakehouse.service.redis.RedisOperateClient;
import com.chinamobile.cmss.lakehouse.service.sqlconsole.SQLTaskService;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

@Slf4j
@Service
public class SqlConsoleServiceImpl extends BaseServiceImpl implements SqlConsoleService {

    @Autowired
    private HiveService hiveService;

    @Autowired
    private LakehouseInstanceService lakehouseInstanceService;

    @Autowired
    private SQLTaskService sqlTaskService;

    @Autowired
    SparkSQLTaskInfoDao sparkSQLTaskInfoDao;

    @Autowired
    SparkFilenameTaskDao sparkFilenameTaskDao;

    @Autowired
    RedisOperateClient redisOperateClient;

    /**
     * Decodes and pretty-prints a SQL statement after checking the user's Hive permission on the instance.
     *
     * @param userId   id of the requesting user, passed to the Hive permission check
     * @param sqlQuery encoded SQL text, decoded with {@link SqlDruidUtil#decodeSql}
     * @param dbType   SQL dialect name, lower-cased locale-independently before formatting
     * @param instance lakehouse instance whose Hive permission is checked
     * @return result map with the formatted SQL under {@link Constants#DATA_LIST} and a status
     */
    @Override
    public Map<String, Object> getFormattedSql(String userId, String sqlQuery, String dbType, String instance) {
        Map<String, Object> result = new HashMap<>();

        try {
            if (StringUtils.isEmpty(sqlQuery)) {
                putMessage(result, Status.SQL_CONSOLE_SQL_NULL_ERROR);
                return result;
            }
            // validate hive db
            if (!hiveService.validateUserAuthority(userId, instance)) {
                putMessage(result, Status.SQL_CONSOLE_NO_HIVE_PERMISSION_ERROR);
                return result;
            }
            sqlQuery = SqlDruidUtil.decodeSql(sqlQuery);
            // Locale.ROOT: default-locale lower-casing turns e.g. "HIVE" into "hıve" under a Turkish locale.
            String formattedSql = SqlDruidUtil.formattedSql(sqlQuery, dbType.toLowerCase(Locale.ROOT));
            result.put(Constants.DATA_LIST, formattedSql);
            putMessage(result, Status.SUCCESS);
        } catch (Exception ex) {
            log.error("format sql {} failed", sqlQuery, ex);
            // NOTE(review): assumes the exception message is the name of a Status constant; any other
            // message (or null) makes Status.valueOf throw from inside this handler. Confirm the contract
            // of SqlDruidUtil / HiveService or map to a dedicated fallback status.
            putMessage(result, Status.valueOf(ex.getMessage()));
        }
        return result;
    }

    /**
     * Dispatches the user's SQL to the engine behind the requested instance.
     *
     * <p>An engine may not return results immediately (e.g. Spark); the returned items therefore carry job ids
     * which the client polls through {@link #getQueryResult(String, String)}.
     *
     * @param loginUser         user submitting the SQL
     * @param sqlQueryParamBean request holding the encoded SQL, instance, engine type and database
     * @return result map with a list of {@link SqlQueryJobItemDto} under {@link Constants#DATA_LIST} on success,
     *     or an error status when the instance is unknown/frozen or dispatching fails
     */
    @Override
    public Map<String, Object> executeSql(UserEntity loginUser, SqlQueryParamBean sqlQueryParamBean) {
        Map<String, Object> result = new HashMap<>();

        sqlQueryParamBean.setQuerySql(SqlDruidUtil.decodeSql(sqlQueryParamBean.getQuerySql()));

        Map<String, Object> instanceMap = lakehouseInstanceService.getInstanceList(loginUser.getUserId(), sqlQueryParamBean.getEngineType());
        @SuppressWarnings("unchecked")
        List<LakehouseClusterInfoEntity> instanceDataList = instanceMap == null
            ? null
            : (List<LakehouseClusterInfoEntity>) instanceMap.get(Constants.DATA_LIST);
        LakehouseClusterInfoEntity targetInstance = instanceDataList == null ? null : instanceDataList.stream()
            .filter(x -> x.getInstance() != null && x.getInstance().equals(sqlQueryParamBean.getInstance()))
            .findFirst()
            .orElse(null);

        if (targetInstance == null) {
            putMessage(result, Status.LAKEHOUSE_INSTANCE_LIST_NULL_ERROR);
            return result;
        }
        if (ClusterStatusTypeEnum.FROZEN.getStatus().equals(targetInstance.getStatus())) {
            putMessage(result, Status.LAKEHOUSE_INSTANCE_FROZEN_ERROR);
            return result;
        }

        // assemble execute params
        ExecuteSQLBean executeSQLBean = ExecuteSQLBean.builder()
            .sqlContext(sqlQueryParamBean.getQuerySql())
            .instance(sqlQueryParamBean.getInstance())
            .engineType(sqlQueryParamBean.getEngineType())
            .dbName(sqlQueryParamBean.getDbName())
            .submitUser(loginUser.getUserName())
            .userId(loginUser.getUserId())
            .build();

        // commit sql to execute engine
        SqlConsoleSQLResult sqlConsoleSQLResult;
        try {
            sqlConsoleSQLResult = sqlTaskService.dispatchSQLTask(executeSQLBean);
        } catch (ServiceException e) {
            // Return immediately: falling through would overwrite this error with SUCCESS below.
            log.error("dispatch sql to instance {} failed", sqlQueryParamBean.getInstance(), e);
            putMessage(result, Status.SQL_CONSOLE_ENGINE_CONN_FAILED_ERROR);
            return result;
        }
        if (null == sqlConsoleSQLResult) {
            putMessage(result, Status.SQL_CONSOLE_SQL_COMMIT_FAILED_ERROR);
            return result;
        }

        // assemble execute result
        List<SqlQueryJobItemDto> jobItems = toJobItems(sqlConsoleSQLResult);
        List<String> jobIds = jobItems.stream().map(SqlQueryJobItemDto::getJobId).collect(Collectors.toList());

        // The result set may only hold job ids and error messages, so look up each job's SQL text from the
        // task table to return it alongside the id (clients poll results later via the query api).
        if (!jobIds.isEmpty()) {
            Map<String, String> sqlByTaskId = indexByTaskId(
                sparkSQLTaskInfoDao.findByTaskIdIn(jobIds), EngineTaskInfoEntity::getSqlContent);
            jobItems.forEach(item -> item.setSql(sqlByTaskId.get(item.getJobId())));
        }

        result.put(Constants.STATUS, Status.SUCCESS);
        result.put(Constants.DATA_LIST, jobItems);
        return result;
    }

    /**
     * Converts the engine's per-statement results into response items; a null result list yields no items.
     */
    private static List<SqlQueryJobItemDto> toJobItems(SqlConsoleSQLResult sqlConsoleSQLResult) {
        if (sqlConsoleSQLResult.getResultsList() == null) {
            return new ArrayList<>();
        }
        return sqlConsoleSQLResult.getResultsList()
            .stream()
            .map(x -> SqlQueryJobItemDto.builder()
                .jobId(x.getTaskID())
                .result(x.getResult() == null ? null : x.getResult().toString())
                .message(x.getMessage())
                .build())
            .collect(Collectors.toList());
    }

    /**
     * Indexes task entities by task id. Unlike {@code Collectors.toMap}, this tolerates null values and
     * duplicate ids (the last entity wins) instead of throwing.
     */
    private static <V> Map<String, V> indexByTaskId(List<EngineTaskInfoEntity> tasks,
                                                    Function<EngineTaskInfoEntity, V> valueExtractor) {
        Map<String, V> index = new HashMap<>();
        if (tasks != null) {
            for (EngineTaskInfoEntity task : tasks) {
                index.put(task.getTaskId(), valueExtractor.apply(task));
            }
        }
        return index;
    }

    /**
     * Requests cancellation of each job and reports a per-job outcome.
     *
     * <p>Jobs already FINISHED or FAILED are reported as not aborted. Jobs already CANCELLED are reported as
     * aborted without contacting the engine. All others are sent to {@link SQLTaskService#killSQLTask}.
     *
     * @param jobIds ids of the jobs to abort; null or empty yields an empty result list
     * @return result map with a list of {@link SqlAbortResultBean} under {@link Constants#DATA_LIST}
     */
    @Override
    public Map<String, Object> abortSql(List<String> jobIds) {
        Map<String, Object> result = new HashMap<>();

        if (CollectionUtils.isEmpty(jobIds)) {
            result.put(Constants.STATUS, Status.SUCCESS);
            result.put(Constants.DATA_LIST, new ArrayList<SqlAbortResultBean>());
            return result;
        }

        Map<String, TaskStatusTypeEnum> statusByTaskId = indexByTaskId(
            sparkSQLTaskInfoDao.findByTaskIdIn(jobIds), EngineTaskInfoEntity::getStatusTypeEnum);

        List<SqlAbortResultBean> sqlAbortResultItemDtoList = jobIds.stream()
            .map(jobId -> abortSingleJob(jobId, statusByTaskId.get(jobId)))
            .collect(Collectors.toList());

        result.put(Constants.STATUS, Status.SUCCESS);
        result.put(Constants.DATA_LIST, sqlAbortResultItemDtoList);
        return result;
    }

    /**
     * Decides the abort outcome for one job given its recorded status (null when the job is not in the task
     * table).
     */
    private SqlAbortResultBean abortSingleJob(String jobId, TaskStatusTypeEnum currentStatus) {
        boolean commitSuccess;
        String message = "";
        if (TaskStatusTypeEnum.FINISHED.equals(currentStatus)) {
            commitSuccess = false;
            message = Status.SQL_CONSOLE_EXEC_SUCCESS_ABORT_ERROR.getMessage();
        } else if (TaskStatusTypeEnum.FAILED.equals(currentStatus)) {
            commitSuccess = false;
            message = Status.SQL_CONSOLE_EXEC_FAILED_ABORT_ERROR.getMessage();
        } else if (TaskStatusTypeEnum.CANCELLED.equals(currentStatus)) {
            // if cancel commit operation, return true directly.
            commitSuccess = true;
        } else {
            LakehouseResponse<Boolean> response = sqlTaskService.killSQLTask(jobId);
            // constant-first comparison so a response without a code counts as a failure, not an NPE
            if (null == response || !HttpStatus.OK.equals(response.getCode())) {
                commitSuccess = false;
                message = Status.SQL_CONSOLE_ABORT_ERROR.getMessage();
            } else {
                // a missing data flag is treated as "not confirmed"
                commitSuccess = Boolean.TRUE.equals(response.getData());
            }
        }

        return SqlAbortResultBean.builder()
            .jobId(jobId)
            .commitSuccessed(commitSuccess)
            .message(message)
            .build();
    }

    /**
     * Returns the status, SQL text, run time and (for finished Spark jobs) the result set of one job.
     *
     * @param jobId      id of the job to look up
     * @param engineType engine type; result rows are read from redis only for the Spark type
     * @return result map with a {@link SqlQueryResultDto} under {@link Constants#DATA_LIST}
     * @throws IllegalArgumentException if no task with {@code jobId} exists
     */
    @Override
    public Map<String, Object> getQueryResult(String jobId, String engineType) {
        Map<String, Object> result = new HashMap<>();

        log.info("query sql task result sqlQueryJobInfoDto :{}", jobId);
        EngineTaskInfoEntity sqlTaskInfoPODto = sparkSQLTaskInfoDao.findByTaskId(jobId);
        if (sqlTaskInfoPODto == null) {
            // fail with context instead of an anonymous NullPointerException further down
            throw new IllegalArgumentException("sql task not found: " + jobId);
        }

        SqlQueryResultDto sqlQueryResultDto = new SqlQueryResultDto();
        if (TaskStatusTypeEnum.FINISHED.equals(sqlTaskInfoPODto.getStatusTypeEnum())) {
            // get result from redis (Spark engine type only); constant-first so a null engineType is tolerated
            if (EngineType.SPARK.getType().equals(engineType)) {
                SqlQueryResultDto cached = redisOperateClient.readAsHashByJobPrefix(sqlTaskInfoPODto.getDriverPodId());
                if (cached != null) {
                    sqlQueryResultDto = cached;
                }
            }
            sqlQueryResultDto.setAffectedRows(CollectionUtils.isEmpty(sqlQueryResultDto.getResultSet())
                ? "0"
                : String.valueOf(sqlQueryResultDto.getResultSet().size()));
        }
        // Assembly dto
        sqlQueryResultDto.setQuerySql(sqlTaskInfoPODto.getSqlContent());
        sqlQueryResultDto.setStatus(sqlTaskInfoPODto.getStatusTypeEnum());
        sqlQueryResultDto.setExecuteTime(String.valueOf(sqlTaskInfoPODto.getRunTime()));
        result.put(Constants.DATA_LIST, sqlQueryResultDto);
        result.put(Constants.STATUS, Status.SUCCESS);
        return result;
    }

    /**
     * Builds the page request. {@code offset} is treated as a 1-based page number, and values below 1 are
     * clamped to the first page. When paging info is missing or the limit is below 1, all rows are returned in
     * a single page.
     */
    private PageRequest getPageRequest(PageBean pageDto, Sort sortDto) {
        if (Objects.isNull(pageDto) || Objects.isNull(pageDto.getOffset()) || Objects.isNull(pageDto.getLimit())
            || pageDto.getLimit() < 1) {
            log.debug("the request pageDto is null or incomplete, returning all rows in one page");
            return PageRequest.of(0, Integer.MAX_VALUE, sortDto);
        }
        int pageIndex = Math.max(pageDto.getOffset() - 1, 0);
        return PageRequest.of(pageIndex, pageDto.getLimit(), sortDto);
    }

    /**
     * Builds the JPA filter for the task list from the query parameters.
     *
     * <p>Results are always restricted to clusters returned by
     * {@link LakehouseInstanceService#getInstanceNames} for the user. Requesting an instance outside that set
     * yields no rows.
     */
    @VisibleForTesting
    public Specification<EngineTaskInfoEntity> getSparkTaskInoSpecification(SqlJobQueryParamDto sqlJobQueryParamDto) {
        return (root, criteriaQuery, cBuilder) -> {
            Predicate p = cBuilder.conjunction();
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getTaskId())) {
                p = cBuilder.and(p, cBuilder.like(root.get("taskId"), "%" + sqlJobQueryParamDto.getTaskId() + "%"));
            }
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getTaskType())) {
                // comma-separated list; surrounding whitespace and empty entries are ignored
                List<String> taskTypes = Arrays.stream(sqlJobQueryParamDto.getTaskType().split(","))
                    .map(String::trim)
                    .filter(StringUtils::isNotEmpty)
                    .collect(Collectors.toList());
                p = cBuilder.and(p, inOrNone(cBuilder, root, "taskType", taskTypes));
            }
            if (!CollectionUtils.isEmpty(sqlJobQueryParamDto.getStatus())) {
                p = cBuilder.and(p, inOrNone(cBuilder, root, "statusTypeEnum", sqlJobQueryParamDto.getStatus()));
            }
            if (null != sqlJobQueryParamDto.getSubmitTimeFrom()) {
                p = cBuilder.and(p, cBuilder.greaterThanOrEqualTo(root.get("submitTime"), sqlJobQueryParamDto.getSubmitTimeFrom()));
            }
            if (null != sqlJobQueryParamDto.getSubmitTimeTo()) {
                p = cBuilder.and(p, cBuilder.lessThanOrEqualTo(root.get("submitTime"), sqlJobQueryParamDto.getSubmitTimeTo()));
            }

            Map<String, Object> instanceMap = lakehouseInstanceService.getInstanceNames(sqlJobQueryParamDto.getUserId(), sqlJobQueryParamDto.getEngineType());
            Collection<String> permittedInstances = instanceMap == null ? Collections.<String>emptyList() : instanceMap.keySet();
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getInstance())) {
                // check user query permission
                if (permittedInstances.contains(sqlJobQueryParamDto.getInstance())) {
                    // exact match: clusterId holds the instance name (see the IN filter below); LIKE '%x%'
                    // would also expose other instances whose names contain x
                    p = cBuilder.and(p, cBuilder.equal(root.get("clusterId"), sqlJobQueryParamDto.getInstance()));
                } else {
                    log.error("userId {} does not has right to access cluster {}", sqlJobQueryParamDto.getUserId(), sqlJobQueryParamDto.getInstance());
                    // deny: match nothing instead of silently dropping the cluster restriction
                    p = cBuilder.and(p, cBuilder.disjunction());
                }
            } else {
                p = cBuilder.and(p, inOrNone(cBuilder, root, "clusterId", permittedInstances));
            }

            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getScheduleModel())) {
                List<String> taskIds = StringUtils.equals(sqlJobQueryParamDto.getScheduleModel(), "MANUAL")
                    ? sparkFilenameTaskDao.findTaskIdByUserId(sqlJobQueryParamDto.getUserId())
                    : sparkFilenameTaskDao.findByScheduleModelAndUserId(sqlJobQueryParamDto.getScheduleModel(), sqlJobQueryParamDto.getUserId());
                p = cBuilder.and(p, inOrNone(cBuilder, root, "taskId", taskIds));
            }
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getRelatedTaskId())) {
                List<String> taskIds = sparkFilenameTaskDao.findByRelatedTaskIdAndUserId(sqlJobQueryParamDto.getRelatedTaskId(), sqlJobQueryParamDto.getUserId());
                p = cBuilder.and(p, inOrNone(cBuilder, root, "taskId", taskIds));
            }
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getRelatedTaskName())) {
                List<String> taskIds = sparkFilenameTaskDao.findByRelatedTaskNameAndUserId(sqlJobQueryParamDto.getRelatedTaskName(), sqlJobQueryParamDto.getUserId());
                p = cBuilder.and(p, inOrNone(cBuilder, root, "taskId", taskIds));
            }
            //add engineType filter
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getEngineType())) {
                p = cBuilder.and(p, cBuilder.equal(root.get("engineType"), sqlJobQueryParamDto.getEngineType()));
            }
            if (StringUtils.isNotBlank(sqlJobQueryParamDto.getDbName())) {
                p = cBuilder.and(p, cBuilder.equal(root.get("dbName"), sqlJobQueryParamDto.getDbName()));
            }
            return p;
        };
    }

    /**
     * Builds {@code attribute IN (values)}, or an always-false predicate when {@code values} is null or empty.
     * An empty IN list is invalid SQL on many databases, and "in no values" should match nothing.
     */
    private static Predicate inOrNone(CriteriaBuilder cBuilder, Root<EngineTaskInfoEntity> root,
                                      String attribute, Collection<?> values) {
        if (values == null || values.isEmpty()) {
            return cBuilder.disjunction();
        }
        return root.get(attribute).in(values);
    }

    /**
     * Returns a sort on {@code sortField} in {@code sortOrder}. If either is blank, it sorts by createTime
     * descending.
     */
    private Sort sortBy(String sortOrder, String sortField) {
        if (!StringUtils.isBlank(sortField) && !StringUtils.isBlank(sortOrder)) {
            return Sort.by(Sort.Direction.fromString(sortOrder), sortField);
        }
        return Sort.by(Sort.Direction.DESC, "createTime");
    }

    /**
     * Pages through the SQL task records visible to the user. Each record is enriched with its schedule
     * metadata when a matching filename-task entry exists.
     *
     * @param sqlJobQueryParamDto filter, sort and paging parameters
     * @return result map with a {@link PageInfo} under {@link Constants#DATA_LIST}
     */
    @Override
    public Map<String, Object> querySqlTaskInfo(SqlJobQueryParamDto sqlJobQueryParamDto) {
        Map<String, Object> result = new HashMap<>();

        log.info("query sql task info queryParamDto :{}", sqlJobQueryParamDto);
        Sort sort = sortBy(sqlJobQueryParamDto.getSortOrder(), sqlJobQueryParamDto.getSortField());
        PageRequest pageRequest = getPageRequest(sqlJobQueryParamDto.getPageDto(), sort);
        Specification<EngineTaskInfoEntity> specification = getSparkTaskInoSpecification(sqlJobQueryParamDto);
        Page<EngineTaskInfoEntity> page = sparkSQLTaskInfoDao.findAll(specification, pageRequest);
        List<EngineTaskInfoEntity> sqlTaskInfoPODtos = page.getContent();

        log.info("query spark task info page is {}", pageRequest);

        ArrayList<SqlJobScheduleResultDto> sqlTaskInfoPOs = new ArrayList<>(sqlTaskInfoPODtos.size());
        for (EngineTaskInfoEntity sqlTaskInfoPODto : sqlTaskInfoPODtos) {
            SqlJobScheduleResultDto sqlTaskInfo = new SqlJobScheduleResultDto();
            BeanUtils.copyProperties(sqlTaskInfoPODto, sqlTaskInfo);
            SparkFilenameTaskEntity sparkFilenameTaskDO = sparkFilenameTaskDao.findByTaskId(sqlTaskInfoPODto.getTaskId());
            if (null != sparkFilenameTaskDO) {
                sqlTaskInfo.setRelatedTaskName(sparkFilenameTaskDO.getRelatedTaskName());
                sqlTaskInfo.setRelatedTaskId(sparkFilenameTaskDO.getRelatedTaskId());
                sqlTaskInfo.setScheduleModel(sparkFilenameTaskDO.getScheduleModel());
                sqlTaskInfo.setSparkContent(sparkFilenameTaskDO.getConfigContent());
            }
            sqlTaskInfoPOs.add(sqlTaskInfo);
        }

        PageInfo<SqlJobScheduleResultDto> pageInfo = new PageInfo<>();
        long total = page.getTotalElements();
        pageInfo.setTotalList(sqlTaskInfoPOs);
        // saturate rather than let a plain (int) cast wrap to a negative count
        pageInfo.setTotal((int) Math.min(total, Integer.MAX_VALUE));

        log.info("query sql task info list :{} , size :{}", sqlTaskInfoPOs, total);

        result.put(Constants.DATA_LIST, pageInfo);
        putMessage(result, Status.SUCCESS);
        return result;
    }
}
